{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cosine Similarity calculation between sentences with Transformers\n",
    "\n",
    "# [Link to my YouTube Video Explaining this whole Notebook](https://www.youtube.com/watch?v=fwDTLQDKJTE&list=PLxqBkZuBynVQEvXfJpq3smfuKq3AiNW-N&index=8)\n",
    "\n",
    "[![Imgur](https://imgur.com/mWUzYiE.png)](https://www.youtube.com/watch?v=fwDTLQDKJTE&list=PLxqBkZuBynVQEvXfJpq3smfuKq3AiNW-N&index=8)\n",
    "\n",
    "---\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Each token (i.e. each piece the tokenizer produces from the original words) is mapped to a single vector of size 768.\n",
    "\n",
    "Here 768 is the number of columns, i.e. the hidden size of the model.\n",
    "\n",
    "Those 768 values are the numerical representation of a single token, which we can use as a contextual word embedding.\n",
    "\n",
    "Because the model outputs one of these vectors for every token, the output for a whole sequence is a tensor of size (number of tokens) by 768.\n",
    "\n",
    "------\n",
    "\n",
    "## BERT has a limit of 512 tokens.\n",
    "\n",
    "Normally, for longer sequences, you just truncate to 512 tokens.\n",
    "\n",
    "The limit comes from the positional embeddings in the Transformer architecture, for which a maximum length has to be fixed.\n",
    "\n",
    "That maximum is kept modest because of the memory needed to handle long texts: attention layers scale quadratically with the sequence length."
   ]
  },
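  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the quadratic scaling mentioned above concrete, here is a small worked calculation. The symbol $n$ (sequence length in tokens) is notation introduced here; the count covers only the attention score matrix of one head in one layer, not the whole model.\n",
    "\n",
    "$$\\text{scores}(n) = n \\times n, \\qquad \\text{scores}(512) = 512^2 = 262{,}144, \\qquad \\text{scores}(1024) = 1024^2 = 1{,}048{,}576 = 4 \\times \\text{scores}(512)$$\n",
    "\n",
    "So doubling the sequence length roughly quadruples this part of the memory cost."
   ]
  },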
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "sentences = [\n",
    "    \"It had been sixteen days since the zombies first attacked.\",\n",
    "    \n",
    "    \"When confronted with a rotary dial phone the teenager was perplexed.\",\n",
    "    \"His confidence would have bee admirable if it wasn't for his stupidity.\",\n",
    "    \"I'm confused: when people ask me what's up, and I point, they groan.\",\n",
    "    \"They called out her name time and again, but were met with nothing but silence.\",\n",
    "    \"After the last zombie attack sixteen days back, they are taking control of the city\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModel\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load model from HuggingFace Hub\n",
    "tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')\n",
    "model = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['input_ids', 'attention_mask'])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "example_sentence = 'its a sunny morning'\n",
    "example_tokens = tokenizer.encode_plus(example_sentence, max_length = 512, truncation = True, padding = 'max_length', return_tensors = 'pt' )\n",
    "example_tokens.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the token ids and attention mask of every sentence\n",
    "tokens = {'input_ids': [], 'attention_mask': []}\n",
    "\n",
    "for sentence in sentences:\n",
    "    # Truncate or pad each sentence to exactly 512 tokens\n",
    "    new_tokens = tokenizer.encode_plus(sentence, max_length=512, truncation=True, padding='max_length', return_tensors='pt')\n",
    "    tokens['input_ids'].append(new_tokens['input_ids'][0])\n",
    "    tokens['attention_mask'].append(new_tokens['attention_mask'][0])\n",
    "\n",
    "# Stack each list of 1-D tensors into a single [num_sentences, 512] tensor\n",
    "tokens['input_ids'] = torch.stack(tokens['input_ids'])\n",
    "tokens['attention_mask'] = torch.stack(tokens['attention_mask'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([6, 512])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokens['input_ids'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BaseModelOutputWithPooling(last_hidden_state=tensor([[[-1.6771e-02,  2.7875e-02,  7.9733e-02,  ...,  8.5568e-02,\n",
       "          -1.9001e-01,  3.5048e-02],\n",
       "         [ 1.0377e-01,  1.3945e-01,  4.4118e-02,  ...,  1.5550e-01,\n",
       "          -2.6237e-01,  2.6015e-02],\n",
       "         [ 7.1797e-02,  9.0426e-02,  3.3001e-02,  ...,  9.4124e-02,\n",
       "          -2.1211e-01,  2.2209e-02],\n",
       "         ...,\n",
       "         [-4.7748e-03,  9.3277e-02,  6.8232e-02,  ...,  2.2988e-01,\n",
       "          -1.6626e-01,  1.0913e-01],\n",
       "         [-4.7748e-03,  9.3277e-02,  6.8232e-02,  ...,  2.2988e-01,\n",
       "          -1.6626e-01,  1.0913e-01],\n",
       "         [-4.7748e-03,  9.3277e-02,  6.8232e-02,  ...,  2.2988e-01,\n",
       "          -1.6626e-01,  1.0913e-01]],\n",
       "\n",
       "        [[ 1.2549e-01, -1.0631e-01,  3.6856e-02,  ...,  6.0813e-02,\n",
       "          -9.1067e-02,  3.6397e-02],\n",
       "         [ 1.5688e-01, -2.4870e-01,  4.6818e-03,  ...,  8.2096e-02,\n",
       "          -1.8744e-01,  7.9757e-02],\n",
       "         [ 1.9245e-01, -1.4947e-01,  6.3418e-02,  ...,  1.9506e-01,\n",
       "          -8.3352e-02,  5.4228e-02],\n",
       "         ...,\n",
       "         [ 1.3306e-01,  4.4404e-02,  5.6143e-02,  ...,  1.1645e-01,\n",
       "          -4.1378e-02,  7.1716e-02],\n",
       "         [ 1.3306e-01,  4.4404e-02,  5.6143e-02,  ...,  1.1645e-01,\n",
       "          -4.1378e-02,  7.1716e-02],\n",
       "         [ 1.3306e-01,  4.4404e-02,  5.6143e-02,  ...,  1.1645e-01,\n",
       "          -4.1378e-02,  7.1716e-02]],\n",
       "\n",
       "        [[-9.0273e-02,  1.1550e-01,  2.5365e-02,  ..., -1.3713e-01,\n",
       "           5.5565e-02, -2.0296e-02],\n",
       "         [-8.7704e-02,  1.1863e-01,  2.8507e-02,  ..., -1.0713e-01,\n",
       "           1.8199e-01,  4.8593e-03],\n",
       "         [ 5.3350e-02,  4.9334e-01,  5.8561e-02,  ..., -1.2538e-01,\n",
       "          -3.3725e-02,  1.1508e-01],\n",
       "         ...,\n",
       "         [-1.0453e-02,  2.6995e-01,  5.4801e-02,  ..., -1.3884e-03,\n",
       "           2.8404e-02,  2.3292e-02],\n",
       "         [-1.0453e-02,  2.6995e-01,  5.4801e-02,  ..., -1.3884e-03,\n",
       "           2.8404e-02,  2.3292e-02],\n",
       "         [-1.0453e-02,  2.6995e-01,  5.4801e-02,  ..., -1.3884e-03,\n",
       "           2.8404e-02,  2.3292e-02]],\n",
       "\n",
       "        [[ 5.0962e-02, -2.5233e-01, -1.8100e-02,  ...,  5.9566e-02,\n",
       "           8.7379e-02, -1.5808e-01],\n",
       "         [ 7.6241e-02, -1.4975e-01,  3.0298e-04,  ...,  7.8922e-02,\n",
       "           9.9042e-02, -1.1577e-01],\n",
       "         [ 1.1795e-01, -1.5593e-01, -4.2907e-02,  ...,  8.0975e-02,\n",
       "           6.0451e-02, -1.7704e-01],\n",
       "         ...,\n",
       "         [ 8.7943e-02,  3.4247e-02,  1.8712e-02,  ...,  8.1033e-02,\n",
       "           6.7395e-02, -1.4431e-01],\n",
       "         [ 8.7943e-02,  3.4247e-02,  1.8712e-02,  ...,  8.1033e-02,\n",
       "           6.7395e-02, -1.4431e-01],\n",
       "         [ 8.7943e-02,  3.4247e-02,  1.8712e-02,  ...,  8.1033e-02,\n",
       "           6.7395e-02, -1.4431e-01]],\n",
       "\n",
       "        [[ 1.3064e-01,  9.8703e-02,  3.3572e-03,  ...,  6.5254e-02,\n",
       "          -3.3477e-02,  6.7913e-02],\n",
       "         [ 2.7716e-01,  1.7462e-01,  1.5868e-02,  ...,  9.0645e-02,\n",
       "          -3.9072e-02,  6.2788e-02],\n",
       "         [ 2.8550e-01, -1.8783e-01, -4.7557e-03,  ...,  1.3484e-01,\n",
       "          -1.4687e-01,  3.7743e-02],\n",
       "         ...,\n",
       "         [ 1.9483e-01,  1.2496e-01, -3.3830e-02,  ...,  1.4517e-01,\n",
       "          -7.7849e-02,  1.3236e-01],\n",
       "         [ 1.9483e-01,  1.2496e-01, -3.3830e-02,  ...,  1.4517e-01,\n",
       "          -7.7849e-02,  1.3236e-01],\n",
       "         [ 1.9483e-01,  1.2496e-01, -3.3830e-02,  ...,  1.4517e-01,\n",
       "          -7.7849e-02,  1.3236e-01]],\n",
       "\n",
       "        [[-1.0382e-01,  9.5768e-02,  1.1970e-01,  ...,  4.1555e-02,\n",
       "           4.1617e-03, -1.6580e-02],\n",
       "         [ 4.8082e-03,  1.8655e-01,  8.2883e-02,  ...,  7.7249e-02,\n",
       "          -8.7127e-02, -2.1065e-02],\n",
       "         [ 7.5384e-02,  1.0104e-01,  5.0913e-02,  ...,  1.4006e-01,\n",
       "          -6.5929e-02, -5.5219e-02],\n",
       "         ...,\n",
       "         [ 2.1896e-02,  1.4125e-01,  7.7325e-02,  ...,  1.3516e-01,\n",
       "          -1.4167e-02,  2.9171e-02],\n",
       "         [ 2.1896e-02,  1.4125e-01,  7.7325e-02,  ...,  1.3516e-01,\n",
       "          -1.4167e-02,  2.9171e-02],\n",
       "         [ 2.1896e-02,  1.4125e-01,  7.7325e-02,  ...,  1.3516e-01,\n",
       "          -1.4167e-02,  2.9171e-02]]], grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[ 0.0292,  0.0679, -0.0471,  ..., -0.0193,  0.0873, -0.0286],\n",
       "        [-0.0788,  0.0187, -0.0788,  ..., -0.0337, -0.0269,  0.0114],\n",
       "        [-0.0389, -0.0091, -0.1164,  ..., -0.0590, -0.0545, -0.1179],\n",
       "        [ 0.0195, -0.0470, -0.0454,  ...,  0.0446,  0.1281, -0.1349],\n",
       "        [ 0.0182, -0.0183,  0.0335,  ...,  0.0306, -0.0680, -0.0026],\n",
       "        [ 0.0558,  0.0283, -0.0429,  ..., -0.0109,  0.1169,  0.0399]],\n",
       "       grad_fn=<TanhBackward0>), hidden_states=None, attentions=None)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outputs = model(**tokens)\n",
    "outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "odict_keys(['last_hidden_state', 'pooler_output'])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "outputs.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-1.6771e-02,  2.7875e-02,  7.9733e-02,  ...,  8.5568e-02,\n",
       "          -1.9001e-01,  3.5048e-02],\n",
       "         [ 1.0377e-01,  1.3945e-01,  4.4118e-02,  ...,  1.5550e-01,\n",
       "          -2.6237e-01,  2.6015e-02],\n",
       "         [ 7.1797e-02,  9.0426e-02,  3.3001e-02,  ...,  9.4124e-02,\n",
       "          -2.1211e-01,  2.2209e-02],\n",
       "         ...,\n",
       "         [-4.7748e-03,  9.3277e-02,  6.8232e-02,  ...,  2.2988e-01,\n",
       "          -1.6626e-01,  1.0913e-01],\n",
       "         [-4.7748e-03,  9.3277e-02,  6.8232e-02,  ...,  2.2988e-01,\n",
       "          -1.6626e-01,  1.0913e-01],\n",
       "         [-4.7748e-03,  9.3277e-02,  6.8232e-02,  ...,  2.2988e-01,\n",
       "          -1.6626e-01,  1.0913e-01]],\n",
       "\n",
       "        [[ 1.2549e-01, -1.0631e-01,  3.6856e-02,  ...,  6.0813e-02,\n",
       "          -9.1067e-02,  3.6397e-02],\n",
       "         [ 1.5688e-01, -2.4870e-01,  4.6818e-03,  ...,  8.2096e-02,\n",
       "          -1.8744e-01,  7.9757e-02],\n",
       "         [ 1.9245e-01, -1.4947e-01,  6.3418e-02,  ...,  1.9506e-01,\n",
       "          -8.3352e-02,  5.4228e-02],\n",
       "         ...,\n",
       "         [ 1.3306e-01,  4.4404e-02,  5.6143e-02,  ...,  1.1645e-01,\n",
       "          -4.1378e-02,  7.1716e-02],\n",
       "         [ 1.3306e-01,  4.4404e-02,  5.6143e-02,  ...,  1.1645e-01,\n",
       "          -4.1378e-02,  7.1716e-02],\n",
       "         [ 1.3306e-01,  4.4404e-02,  5.6143e-02,  ...,  1.1645e-01,\n",
       "          -4.1378e-02,  7.1716e-02]],\n",
       "\n",
       "        [[-9.0273e-02,  1.1550e-01,  2.5365e-02,  ..., -1.3713e-01,\n",
       "           5.5565e-02, -2.0296e-02],\n",
       "         [-8.7704e-02,  1.1863e-01,  2.8507e-02,  ..., -1.0713e-01,\n",
       "           1.8199e-01,  4.8593e-03],\n",
       "         [ 5.3350e-02,  4.9334e-01,  5.8561e-02,  ..., -1.2538e-01,\n",
       "          -3.3725e-02,  1.1508e-01],\n",
       "         ...,\n",
       "         [-1.0453e-02,  2.6995e-01,  5.4801e-02,  ..., -1.3884e-03,\n",
       "           2.8404e-02,  2.3292e-02],\n",
       "         [-1.0453e-02,  2.6995e-01,  5.4801e-02,  ..., -1.3884e-03,\n",
       "           2.8404e-02,  2.3292e-02],\n",
       "         [-1.0453e-02,  2.6995e-01,  5.4801e-02,  ..., -1.3884e-03,\n",
       "           2.8404e-02,  2.3292e-02]],\n",
       "\n",
       "        [[ 5.0962e-02, -2.5233e-01, -1.8100e-02,  ...,  5.9566e-02,\n",
       "           8.7379e-02, -1.5808e-01],\n",
       "         [ 7.6241e-02, -1.4975e-01,  3.0298e-04,  ...,  7.8922e-02,\n",
       "           9.9042e-02, -1.1577e-01],\n",
       "         [ 1.1795e-01, -1.5593e-01, -4.2907e-02,  ...,  8.0975e-02,\n",
       "           6.0451e-02, -1.7704e-01],\n",
       "         ...,\n",
       "         [ 8.7943e-02,  3.4247e-02,  1.8712e-02,  ...,  8.1033e-02,\n",
       "           6.7395e-02, -1.4431e-01],\n",
       "         [ 8.7943e-02,  3.4247e-02,  1.8712e-02,  ...,  8.1033e-02,\n",
       "           6.7395e-02, -1.4431e-01],\n",
       "         [ 8.7943e-02,  3.4247e-02,  1.8712e-02,  ...,  8.1033e-02,\n",
       "           6.7395e-02, -1.4431e-01]],\n",
       "\n",
       "        [[ 1.3064e-01,  9.8703e-02,  3.3572e-03,  ...,  6.5254e-02,\n",
       "          -3.3477e-02,  6.7913e-02],\n",
       "         [ 2.7716e-01,  1.7462e-01,  1.5868e-02,  ...,  9.0645e-02,\n",
       "          -3.9072e-02,  6.2788e-02],\n",
       "         [ 2.8550e-01, -1.8783e-01, -4.7557e-03,  ...,  1.3484e-01,\n",
       "          -1.4687e-01,  3.7743e-02],\n",
       "         ...,\n",
       "         [ 1.9483e-01,  1.2496e-01, -3.3830e-02,  ...,  1.4517e-01,\n",
       "          -7.7849e-02,  1.3236e-01],\n",
       "         [ 1.9483e-01,  1.2496e-01, -3.3830e-02,  ...,  1.4517e-01,\n",
       "          -7.7849e-02,  1.3236e-01],\n",
       "         [ 1.9483e-01,  1.2496e-01, -3.3830e-02,  ...,  1.4517e-01,\n",
       "          -7.7849e-02,  1.3236e-01]],\n",
       "\n",
       "        [[-1.0382e-01,  9.5768e-02,  1.1970e-01,  ...,  4.1555e-02,\n",
       "           4.1617e-03, -1.6580e-02],\n",
       "         [ 4.8082e-03,  1.8655e-01,  8.2883e-02,  ...,  7.7249e-02,\n",
       "          -8.7127e-02, -2.1065e-02],\n",
       "         [ 7.5384e-02,  1.0104e-01,  5.0913e-02,  ...,  1.4006e-01,\n",
       "          -6.5929e-02, -5.5219e-02],\n",
       "         ...,\n",
       "         [ 2.1896e-02,  1.4125e-01,  7.7325e-02,  ...,  1.3516e-01,\n",
       "          -1.4167e-02,  2.9171e-02],\n",
       "         [ 2.1896e-02,  1.4125e-01,  7.7325e-02,  ...,  1.3516e-01,\n",
       "          -1.4167e-02,  2.9171e-02],\n",
       "         [ 2.1896e-02,  1.4125e-01,  7.7325e-02,  ...,  1.3516e-01,\n",
       "          -1.4167e-02,  2.9171e-02]]], grad_fn=<NativeLayerNormBackward0>)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embeddings = outputs.last_hidden_state\n",
    "embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([6, 512, 768])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embeddings.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "-------------\n",
    "\n",
    "## What is the basic concept behind Pooling?\n",
    "\n",
    "Pooling turns a variable-sized sentence into a fixed-size sentence embedding. A pooling layer can also use the CLS token if the underlying word embedding model returns it, and you can concatenate multiple poolings together.\n",
    "\n",
    "### We can also take the last hidden state, of shape [batch_size, seq_len, hidden_size], and average across the seq_len dimension to get mean embeddings.\n",
    "\n",
    "## Mean Pooling\n",
    "\n",
    "BERT produces contextualized word embeddings for all input tokens in our text. As we want a fixed-size output representation (vector u), we need a pooling layer. Different pooling options are available; the most basic one is mean pooling: we simply average all the contextualized word embeddings BERT gives us. This yields a fixed 768-dimensional output vector that is independent of how long the input text was.\n",
    "\n",
    "---------\n",
    "\n",
    "## From Sentence Transformer to Cosine Similarity Calculation\n",
    "\n",
    "### To transform our **`last_hidden_state`** tensor into the desired vector, we use mean pooling.\n",
    "\n",
    "### Each of these 512 tokens has its own 768 values. Mean pooling takes the average of all token embeddings and consolidates them into a single 768-dimensional vector, producing a ‘sentence vector’.\n",
    "\n",
    "That is why, after we have produced our dense vectors `embeddings`, we perform a *mean pooling* operation on them to create a single vector encoding (the **sentence embedding**).\n",
    "\n",
    "--------------\n",
    "\n",
    "### To do this mean pooling operation, we multiply each value in our `embeddings` tensor by its respective `attention_mask` value, so that padding (non-real) tokens are zeroed out. Only tokens that are not padding should count towards the average.\n",
    "\n",
    "To perform this operation, we first resize our `attention_mask` tensor:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([6, 512])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "attention_mask = tokens['attention_mask']\n",
    "attention_mask.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([6, 512, 768])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Add a trailing dimension ([6, 512] -> [6, 512, 1]), then expand it to the embeddings' shape\n",
    "resized_attention_mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()\n",
    "resized_attention_mask.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[1., 1., 1.,  ..., 1., 1., 1.],\n",
       "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "         ...,\n",
       "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "         [0., 0., 0.,  ..., 0., 0., 0.]],\n",
       "\n",
       "        [[1., 1., 1.,  ..., 1., 1., 1.],\n",
       "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "         ...,\n",
       "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "         [0., 0., 0.,  ..., 0., 0., 0.]],\n",
       "\n",
       "        [[1., 1., 1.,  ..., 1., 1., 1.],\n",
       "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "         ...,\n",
       "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "         [0., 0., 0.,  ..., 0., 0., 0.]],\n",
       "\n",
       "        [[1., 1., 1.,  ..., 1., 1., 1.],\n",
       "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "         ...,\n",
       "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "         [0., 0., 0.,  ..., 0., 0., 0.]],\n",
       "\n",
       "        [[1., 1., 1.,  ..., 1., 1., 1.],\n",
       "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "         ...,\n",
       "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "         [0., 0., 0.,  ..., 0., 0., 0.]],\n",
       "\n",
       "        [[1., 1., 1.,  ..., 1., 1., 1.],\n",
       "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "         [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "         ...,\n",
       "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "         [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "         [0., 0., 0.,  ..., 0., 0., 0.]]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "resized_attention_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([768])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "resized_attention_mask[0][0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([6, 512, 768])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "masked_embedding = embeddings * resized_attention_mask\n",
    "masked_embedding.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-1.6771e-02,  2.7875e-02,  7.9733e-02,  ...,  8.5568e-02,\n",
       "          -1.9001e-01,  3.5048e-02],\n",
       "         [ 1.0377e-01,  1.3945e-01,  4.4118e-02,  ...,  1.5550e-01,\n",
       "          -2.6237e-01,  2.6015e-02],\n",
       "         [ 7.1797e-02,  9.0426e-02,  3.3001e-02,  ...,  9.4124e-02,\n",
       "          -2.1211e-01,  2.2209e-02],\n",
       "         ...,\n",
       "         [-0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n",
       "          -0.0000e+00,  0.0000e+00],\n",
       "         [-0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n",
       "          -0.0000e+00,  0.0000e+00],\n",
       "         [-0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n",
       "          -0.0000e+00,  0.0000e+00]],\n",
       "\n",
       "        [[ 1.2549e-01, -1.0631e-01,  3.6856e-02,  ...,  6.0813e-02,\n",
       "          -9.1067e-02,  3.6397e-02],\n",
       "         [ 1.5688e-01, -2.4870e-01,  4.6818e-03,  ...,  8.2096e-02,\n",
       "          -1.8744e-01,  7.9757e-02],\n",
       "         [ 1.9245e-01, -1.4947e-01,  6.3418e-02,  ...,  1.9506e-01,\n",
       "          -8.3352e-02,  5.4228e-02],\n",
       "         ...,\n",
       "         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n",
       "          -0.0000e+00,  0.0000e+00],\n",
       "         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n",
       "          -0.0000e+00,  0.0000e+00],\n",
       "         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n",
       "          -0.0000e+00,  0.0000e+00]],\n",
       "\n",
       "        [[-9.0273e-02,  1.1550e-01,  2.5365e-02,  ..., -1.3713e-01,\n",
       "           5.5565e-02, -2.0296e-02],\n",
       "         [-8.7704e-02,  1.1863e-01,  2.8507e-02,  ..., -1.0713e-01,\n",
       "           1.8199e-01,  4.8593e-03],\n",
       "         [ 5.3350e-02,  4.9334e-01,  5.8561e-02,  ..., -1.2538e-01,\n",
       "          -3.3725e-02,  1.1508e-01],\n",
       "         ...,\n",
       "         [-0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -0.0000e+00,\n",
       "           0.0000e+00,  0.0000e+00],\n",
       "         [-0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -0.0000e+00,\n",
       "           0.0000e+00,  0.0000e+00],\n",
       "         [-0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -0.0000e+00,\n",
       "           0.0000e+00,  0.0000e+00]],\n",
       "\n",
       "        [[ 5.0962e-02, -2.5233e-01, -1.8100e-02,  ...,  5.9566e-02,\n",
       "           8.7379e-02, -1.5808e-01],\n",
       "         [ 7.6241e-02, -1.4975e-01,  3.0298e-04,  ...,  7.8922e-02,\n",
       "           9.9042e-02, -1.1577e-01],\n",
       "         [ 1.1795e-01, -1.5593e-01, -4.2907e-02,  ...,  8.0975e-02,\n",
       "           6.0451e-02, -1.7704e-01],\n",
       "         ...,\n",
       "         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n",
       "           0.0000e+00, -0.0000e+00],\n",
       "         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n",
       "           0.0000e+00, -0.0000e+00],\n",
       "         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n",
       "           0.0000e+00, -0.0000e+00]],\n",
       "\n",
       "        [[ 1.3064e-01,  9.8703e-02,  3.3572e-03,  ...,  6.5254e-02,\n",
       "          -3.3477e-02,  6.7913e-02],\n",
       "         [ 2.7716e-01,  1.7462e-01,  1.5868e-02,  ...,  9.0645e-02,\n",
       "          -3.9072e-02,  6.2788e-02],\n",
       "         [ 2.8550e-01, -1.8783e-01, -4.7557e-03,  ...,  1.3484e-01,\n",
       "          -1.4687e-01,  3.7743e-02],\n",
       "         ...,\n",
       "         [ 0.0000e+00,  0.0000e+00, -0.0000e+00,  ...,  0.0000e+00,\n",
       "          -0.0000e+00,  0.0000e+00],\n",
       "         [ 0.0000e+00,  0.0000e+00, -0.0000e+00,  ...,  0.0000e+00,\n",
       "          -0.0000e+00,  0.0000e+00],\n",
       "         [ 0.0000e+00,  0.0000e+00, -0.0000e+00,  ...,  0.0000e+00,\n",
       "          -0.0000e+00,  0.0000e+00]],\n",
       "\n",
       "        [[-1.0382e-01,  9.5768e-02,  1.1970e-01,  ...,  4.1555e-02,\n",
       "           4.1617e-03, -1.6580e-02],\n",
       "         [ 4.8082e-03,  1.8655e-01,  8.2883e-02,  ...,  7.7249e-02,\n",
       "          -8.7127e-02, -2.1065e-02],\n",
       "         [ 7.5384e-02,  1.0104e-01,  5.0913e-02,  ...,  1.4006e-01,\n",
       "          -6.5929e-02, -5.5219e-02],\n",
       "         ...,\n",
       "         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n",
       "          -0.0000e+00,  0.0000e+00],\n",
       "         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n",
       "          -0.0000e+00,  0.0000e+00],\n",
       "         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,\n",
       "          -0.0000e+00,  0.0000e+00]]], grad_fn=<MulBackward0>)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "masked_embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([6, 768])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sum over the token dimension (dim 1); padding rows are all zeros and add nothing\n",
    "summed_masked_embeddings = torch.sum(masked_embedding, 1)\n",
    "summed_masked_embeddings.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.5913,  1.1332,  0.5017,  ...,  1.8387, -3.1667,  0.8035],\n",
       "        [ 2.5646, -1.7509,  0.7735,  ...,  1.0171, -1.7503,  1.2193],\n",
       "        [-0.5600,  3.2766,  0.5741,  ..., -2.1975,  0.9449,  0.7661],\n",
       "        [ 1.5776, -3.6690, -0.8673,  ...,  2.7474,  2.0530, -3.2321],\n",
       "        [ 4.1955,  0.8502, -0.0662,  ...,  2.1655, -2.1070,  1.7098],\n",
       "        [-0.0775,  2.2288,  1.7956,  ...,  1.4763, -0.8778, -0.6049]],\n",
       "       grad_fn=<SumBackward1>)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "summed_masked_embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([6, 768])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Count the real (non-padding) tokens per sentence; clamp guards against division by zero\n",
    "count_of_one_in_mask_tensor = torch.clamp(resized_attention_mask.sum(1), min=1e-9)\n",
    "\n",
    "count_of_one_in_mask_tensor.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[13., 13., 13.,  ..., 13., 13., 13.],\n",
       "        [16., 16., 16.,  ..., 16., 16., 16.],\n",
       "        [19., 19., 19.,  ..., 19., 19., 19.],\n",
       "        [23., 23., 23.,  ..., 23., 23., 23.],\n",
       "        [19., 19., 19.,  ..., 19., 19., 19.],\n",
       "        [18., 18., 18.,  ..., 18., 18., 18.]])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "count_of_one_in_mask_tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([6, 768])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "summed_masked_embeddings.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([6, 768])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "count_of_one_in_mask_tensor.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_pooled = summed_masked_embeddings / count_of_one_in_mask_tensor"
   ]
  },
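  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cells above can be summarised as one formula. This is a sketch in notation that does not appear in the notebook: $h_{i,t}$ is assumed to be the 768-dimensional embedding of token $t$ in sentence $i$, $m_{i,t} \\in \\{0, 1\\}$ its `attention_mask` value, $T = 512$ the padded length, and $s_i$ the resulting sentence vector.\n",
    "\n",
    "$$s_i = \\frac{\\sum_{t=1}^{T} m_{i,t} \\, h_{i,t}}{\\max\\left(\\sum_{t=1}^{T} m_{i,t},\\ 10^{-9}\\right)}$$\n",
    "\n",
    "The numerator is `summed_masked_embeddings`, the denominator is `count_of_one_in_mask_tensor`, and the lower bound of $10^{-9}$ only matters if a mask were all zeros."
   ]
  },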
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([6, 768])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mean_pooled.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Calculate cosine similarity for sentence `0`:"
   ]
  },
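  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reminder of what the next cell computes, here is the standard definition of cosine similarity. The symbols $u$ and $v$ are notation introduced here for two mean-pooled sentence vectors; they do not appear in the code.\n",
    "\n",
    "$$\\cos(u, v) = \\frac{u \\cdot v}{\\lVert u \\rVert \\, \\lVert v \\rVert} = \\frac{\\sum_{k=1}^{768} u_k v_k}{\\sqrt{\\sum_{k=1}^{768} u_k^2} \\, \\sqrt{\\sum_{k=1}^{768} v_k^2}}$$\n",
    "\n",
    "The value lies in $[-1, 1]$: values near 1 mean the two vectors point in almost the same direction, and values near 0 mean they are close to orthogonal."
   ]
  },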
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.12543598,  0.05873729, -0.00801761,  0.20136353,  0.6832719 ]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "\n",
    "# Detach from the autograd graph and convert to NumPy for scikit-learn\n",
    "mean_pooled = mean_pooled.detach().numpy()\n",
    "\n",
    "# Compare sentence 0 against each of the remaining five sentences\n",
    "cosine_similarity([mean_pooled[0]], mean_pooled[1:])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.13 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "f9f85f796d01129d0dd105a088854619f454435301f6ffec2fea96ecbd9be4ac"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
